import taichi as ti

from .common import Vec3


@ti.data_oriented
class GaussBloom:
    """Separable Gaussian bloom over a 2-D ``Vec3`` image.

    ``produce`` keeps only the over-bright part of ``src`` (see ``filter``)
    and blurs it with a symmetric Gaussian kernel. It blurs first along
    axis 0 into the scratch field ``tmp``, then along axis 1, and adds the
    result to ``out``.

    Args:
        size: 2-D field shape used for the scratch field ``tmp``. The
            ``src`` / ``out`` fields given to ``produce`` are assumed to
            have this same shape (TODO confirm at call sites).
        radius: Blur radius in pixels (number of taps on each side of the
            centre). Defaults to ``min(*size) // 16``.

    Raises:
        ValueError: If ``radius`` is negative.
    """

    # Minimum number of weight slots. Keeping the old capacity of 128 means
    # later ``initWeight`` calls with any radius up to 127 still fit.
    _MIN_WEIGHT_SLOTS = 128

    def __init__(self, size, radius=None):
        if radius is None:
            radius = min(*size) // 16
        if radius < 0:
            raise ValueError(f"radius must be non-negative, got {radius}")
        self.tmp = Vec3.field(shape=size)
        # One slot per tap offset 0..radius. A fixed 128 slots overflowed as
        # soon as min(*size) >= 2048 (radius >= 128).
        self.weight = ti.field(float, max(self._MIN_WEIGHT_SLOTS, radius + 1))
        self.radius = ti.field(int, 1)
        self.initWeight(radius)

    @ti.kernel
    def initWeight(self, radius: int):
        """Fill ``weight[0..r]`` with a normalized one-sided Gaussian kernel.

        ``weight[i]`` is the tap used for both offsets ``+i`` and ``-i``.
        Sigma is ``r / 3``, so before normalization the outermost tap is
        ``exp(-4.5)`` (about 1.1%) of the centre tap. The weights are
        normalized so that ``weight[0] + 2 * sum(weight[1:r + 1]) == 1``.

        ``r`` is ``radius`` clamped to the weight field's capacity. It is
        stored in ``self.radius[0]`` for ``produce``.
        """
        # Clamp so an oversized radius cannot write past the end of `weight`.
        r = min(radius, self.weight.shape[0] - 1)
        self.radius[0] = r
        # With r == 0 the sigma would be 0, giving 0/0 = NaN. A floor of 1
        # yields the identity kernel (weight[0] == 1) instead.
        sigma = max(r, 1) / 3.0
        total = 0.0
        for i in range(r + 1):
            self.weight[i] = ti.exp(-((i / sigma) ** 2) / 2)
            # Every tap except the centre one is applied on both sides.
            total += self.weight[i] * (1 + (i > 0))
        for i in range(r + 1):
            self.weight[i] /= total

    @ti.func
    def filter(self, rgb):
        """Return the over-bright part of ``rgb``, element-wise.

        A component ``<= 1`` maps to 0. A component ``v > 1`` maps to
        ``1 - 1 / v``, which is always below 1. A harder-edged alternative
        threshold would be ``max(rgb - 1, 0)``.
        """
        return 1 - 1 / max(1, rgb)

    @ti.kernel
    def produce(self, src: ti.template(), out: ti.template()):
        """Add the bloom of ``src`` to ``out``.

        ``out`` is accumulated into, not cleared. Pass 1 blurs
        ``filter(src)`` along axis 0 into ``self.tmp``. Pass 2 blurs
        ``self.tmp`` along axis 1 and adds the result to ``out``. Samples
        past the border are clamped to the nearest edge pixel.
        """
        last_x = self.tmp.shape[0] - 1
        last_y = self.tmp.shape[1] - 1
        # Relies on Taichi running the two top-level loops as separate,
        # sequential tasks, so pass 2 only sees a fully written `tmp`.
        for x, y in self.tmp:
            acc = self.filter(src[x, y]) * self.weight[0]
            for i in range(1, self.radius[0] + 1):
                before = self.filter(src[max(0, x - i), y])
                after = self.filter(src[min(last_x, x + i), y])
                acc += (before + after) * self.weight[i]
            self.tmp[x, y] = acc
        for x, y in self.tmp:
            acc = self.tmp[x, y] * self.weight[0]
            for i in range(1, self.radius[0] + 1):
                before = self.tmp[x, max(0, y - i)]
                after = self.tmp[x, min(last_y, y + i)]
                acc += (before + after) * self.weight[i]
            out[x, y] += acc
